# thanks to https://github.com/prajwalsingh/EEGStyleGAN-ADA
import os
import numpy as np
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
from natsort import natsorted
import os
device = "cuda"
from glob import glob
from pytorch_metric_learning import miners, losses
from pytorch_metric_learning import regularizers
from braindecode.augmentation import FTSurrogate, SmoothTimeMask
import random
from sklearn.cluster import KMeans
from scipy.optimize import linear_sum_assignment as linear_assignment
from tqdm import tqdm
    
import pickle
from transformers import ViTModel, ViTConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import torch
from torch import nn

from einops import rearrange
from einops.layers.torch import Rearrange


from torch.utils.data import TensorDataset, DataLoader, Dataset

import torch
from torch import nn

from einops import rearrange
from einops.layers.torch import Rearrange


# Module-level EEG augmentations, built once at import time and moved to `device`.
# `Sur` (braindecode FTSurrogate) is consumed by `train`; `mask` is not referenced
# anywhere else in this file's visible code.
Sur = FTSurrogate(phase_noise_magnitude=1, probability=0.5).to(device)
mask = SmoothTimeMask(mask_len_samples=16, probability=0.5).to(device)


class K_means:
    """Score embeddings by k-means clustering against ground-truth labels.

    NOTE(review): a second ``K_means`` class defined later in this module rebinds
    the name, so this definition is shadowed once the module finishes importing.
    """

    def __init__(self, n_clusters=40, random_state=45):
        self.n_clusters = n_clusters
        self.random_state = random_state

    def transform(self, embed, gt_labels):
        """Cluster ``embed`` with k-means and return the Hungarian-matched accuracy."""
        pred_labels = KMeans(n_clusters=self.n_clusters, random_state=self.random_state).fit_predict(embed)
        return self.cluster_acc(gt_labels, pred_labels)

    def cluster_acc(self, y_true, y_pred):
        """Return clustering accuracy under the best one-to-one cluster->label mapping.

        Args:
          y_true: integer ground-truth labels, ``n_samples`` elements.
          y_pred: integer cluster ids, ``n_samples`` elements.

        Returns:
          accuracy as a float in [0, 1].

        Raises:
          ValueError: if the inputs differ in size or are empty.
        """
        y_true = np.asarray(y_true).astype(np.int64).ravel()
        y_pred = np.asarray(y_pred).astype(np.int64).ravel()
        if y_pred.size != y_true.size:
            raise ValueError('y_true and y_pred must have the same number of elements')
        if y_pred.size == 0:
            raise ValueError('cannot compute clustering accuracy of empty inputs')

        size = max(y_pred.max(), y_true.max()) + 1
        # counts[p, t] = number of samples with cluster id p and true label t.
        counts = np.zeros((size, size), dtype=np.int64)
        np.add.at(counts, (y_pred, y_true), 1)

        # Hungarian matching maximizes matched counts (minimize max - counts).
        rows, cols = linear_assignment(counts.max() - counts)
        return counts[rows, cols].sum() * 1.0 / y_pred.size

    # Thanks to: https://github.com/k-han/DTC/blob/master/utils/util.py
class K_means:
    """Evaluate embeddings by k-means clustering against ground-truth labels."""

    def __init__(self, n_clusters=40, random_state=45):
        self.n_clusters = n_clusters
        self.random_state = random_state

    def transform(self, embed, gt_labels):
        """Cluster ``embed`` with k-means and return Hungarian-matched accuracy vs ``gt_labels``."""
        pred_labels = KMeans(n_clusters=self.n_clusters, random_state=self.random_state).fit_predict(embed)
        accuracy = self.cluster_metrics(gt_labels, pred_labels)
        return accuracy

    def cluster_metrics(self, y_true, y_pred):
        """
        Calculate clustering accuracy under the best one-to-one cluster->label mapping.

        Arguments:
          y_true: true labels, integer array with ``n_samples`` elements
          y_pred: predicted cluster ids, integer array with ``n_samples`` elements

        Returns:
          accuracy: float, in [0,1]

        Raises:
          ValueError: if the inputs differ in size or are empty.
        """
        y_true = np.asarray(y_true).astype(np.int64).ravel()
        y_pred = np.asarray(y_pred).astype(np.int64).ravel()
        if y_pred.size != y_true.size:
            raise ValueError('y_true and y_pred must have the same number of elements')
        if y_pred.size == 0:
            raise ValueError('cannot compute clustering metrics of empty inputs')

        D = max(y_pred.max(), y_true.max()) + 1
        # w[p, t] = number of samples assigned to cluster p whose true label is t.
        w = np.zeros((D, D), dtype=np.int64)
        np.add.at(w, (y_pred, y_true), 1)

        # Hungarian matching: maximize matched counts by minimizing (max - w).
        rows, cols = linear_assignment(w.max() - w)
        accuracy = w[rows, cols].sum() * 1.0 / y_pred.size
        return accuracy
    
    


def pair(t):
    """Return ``t`` unchanged if it is already a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)

def posemb_sincos_2d(patches, temperature = 10000, dtype = None):
    """Build fixed 2-D sine/cosine positional embeddings for a grid of patches.

    Args:
        patches: 4-D tensor unpacked as ``(batch, h, w, dim)``; only its shape,
            device and dtype are read. ``dim`` must be a multiple of 4.
        temperature: base of the geometric frequency progression.
        dtype: dtype of the returned tensor. ``None`` (default) uses ``patches.dtype``.

    Returns:
        Tensor of shape ``(h * w, dim)``. Row ``i`` belongs to grid position
        ``(i // w, i % w)`` and holds ``[sin(x*omega), cos(x*omega), sin(y*omega), cos(y*omega)]``.
    """
    _, h, w, dim = patches.shape
    device = patches.device
    if dtype is None:
        dtype = patches.dtype
    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'

    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
    quarter = dim // 4
    # With dim == 4 there is a single frequency; max(..., 1) avoids 0/0 -> NaN.
    omega = torch.arange(quarter, device = device) / max(quarter - 1, 1)
    omega = 1. / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    return pe.type(dtype)

# patch dropout

class PatchDropout(nn.Module):
    """Keep a random subset of tokens per sample while training.

    Input is unpacked as ``(batch, tokens, features)``. In training mode with a
    non-zero ``prob``, each sample keeps ``max(1, int(tokens * (1 - prob)))``
    tokens chosen by ranking Gaussian noise; otherwise the input is returned as-is.
    """

    def __init__(self, prob):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob

    def forward(self, x):
        active = self.training and self.prob != 0.
        if not active:
            return x

        batch, tokens, _ = x.shape
        dev = x.device

        row_ids = torch.arange(batch, device = dev).unsqueeze(-1)
        keep = max(1, int(tokens * (1 - self.prob)))
        # Top-k of random scores == a uniformly random subset of `keep` token indices.
        scores = torch.randn(batch, tokens, device = dev)
        kept_ids = scores.topk(keep, dim = -1).indices

        return x[row_ids, kept_ids]

# classes

class FeedForward(nn.Module):
    """Pre-norm MLP: LayerNorm -> Linear(dim, hidden_dim) -> GELU -> Linear(hidden_dim, dim)."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        layers = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        ]
        # A single `net` Sequential keeps state_dict keys as `net.<index>.*`.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

class Attention(nn.Module):
    """Multi-head self-attention with a LayerNorm on its input and bias-free projections.

    Expects input unpacked as ``(batch, tokens, dim)`` and returns the same shape.
    """

    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, 3 * inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)
        batch, tokens, _ = x.shape

        def to_heads(t):
            # (batch, tokens, heads * d) -> (batch, heads, tokens, d)
            head_dim = t.shape[-1] // self.heads
            return t.reshape(batch, tokens, self.heads, head_dim).transpose(1, 2)

        q, k, v = [to_heads(t) for t in self.to_qkv(x).chunk(3, dim = -1)]

        scores = (q @ k.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)

        context = weights @ v
        # (batch, heads, tokens, d) -> (batch, tokens, heads * d)
        context = context.transpose(1, 2).reshape(batch, tokens, self.heads * context.shape[-1])
        return self.to_out(context)

class Transformer(nn.Module):
    """``depth`` residual blocks (Attention then FeedForward), followed by a final LayerNorm."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # Each entry is ModuleList([Attention, FeedForward]); this nesting fixes the
        # state_dict keys as `layers.<i>.0.*` / `layers.<i>.1.*`.
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        for attention_block, mlp_block in self.layers:
            x = x + attention_block(x)
            x = x + mlp_block(x)
        return self.norm(x)

class SimpleViT(nn.Module):
    """Patch-based ViT encoder that returns a mean-pooled token embedding.

    Inputs of shape ``(batch, H, W)`` get a singleton channel axis inserted;
    inputs already shaped ``(batch, channels, H, W)`` are used as-is. The image
    is cut into ``patch_size`` patches, embedded to ``dim``, given fixed 2-D
    sin/cos positional embeddings, optionally patch-dropped, passed through a
    ``Transformer`` and averaged over tokens.

    NOTE: ``linear_head`` (``dim -> num_classes``) is constructed, so its
    parameters appear in ``state_dict`` and checkpoints, but ``forward`` does not
    apply it; the output is the ``dim``-sized embedding.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, patch_dropout = 0):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        self.patch_dropout = PatchDropout(patch_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.to_latent = nn.Identity()
        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        # (batch, H, W) -> (batch, 1, H, W); 4-D (batch, C, H, W) input passes through.
        if img.ndim == 3:
            img = img.unsqueeze(1)
        x = self.to_patch_embedding(img)          # (batch, h, w, dim)
        pe = posemb_sincos_2d(x)                   # (h * w, dim)
        x = rearrange(x, 'b ... d -> b (...) d') + pe

        x = self.patch_dropout(x)

        x = self.transformer(x)
        x = x.mean(dim = 1)

        return self.to_latent(x)



class EEGDataset(Dataset):
    """Map-style dataset pairing EEG tensors with labels, rescaling each sample on access.

    Each sample ``x`` is mapped to ``(x - m/2) / (m/2)`` with ``m = max(x)``, so the
    sample's maximum becomes 1. A sample whose maximum is exactly 0 cannot be
    scaled this way (division by zero) and is returned unscaled instead.
    """

    def __init__(self, eegs, labels):
        self.eegs         = eegs    # indexable collection of EEG tensors
        self.labels       = labels  # labels aligned index-by-index with `eegs`

    def __getitem__(self, index):
        eeg    = self.eegs[index]
        norm   = torch.max(eeg) / 2.0
        eeg    = eeg - norm
        # Guard: norm == 0 would produce NaN/inf and poison the whole batch.
        if norm != 0:
            eeg = eeg / norm
        label  = self.labels[index]
        return eeg, label

    def __len__(self):
        return len(self.eegs)
    
    
    
def _embed_dataloader(model, dataloader):
    """Embed every batch of ``dataloader`` with ``model`` in eval mode, without gradients.

    Returns ``(features, labels)`` as NumPy arrays concatenated along axis 0.
    ``model`` is switched back to train mode afterwards, even if an error occurs.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    feature_batches, label_batches = [], []
    model.eval()
    try:
        with torch.no_grad():
            for eeg, labels in tqdm(dataloader):
                eeg = eeg.to(device).squeeze(1)
                feature_batches.append(model(eeg).cpu().numpy())
                label_batches.append(labels.cpu().numpy())
    finally:
        model.train()
    if not feature_batches:
        raise ValueError('dataloader yielded no batches to embed')
    # Concatenate once instead of per batch to avoid quadratic copying.
    return np.concatenate(feature_batches, axis=0), np.concatenate(label_batches, axis=0)


def train(epoch, model, optimizer, loss_fn, miner, train_data, train_dataloader):
    """Run one training epoch of the metric-learning objective.

    Each batch is moved to ``device`` and squeezed on dim 1. It is then passed
    through the module-level ``Sur`` surrogate augmentation, embedded by
    ``model``, and optimized with ``loss_fn`` on the pairs mined by ``miner``.
    Every 5th epoch (including epoch 0) the training set is also embedded in eval
    mode and scored with ``K_means`` (10 clusters); the score is printed.

    Args:
        epoch: current epoch index, used for logging and the every-5-epochs check.
        model: embedding network, called as ``model(eeg)``.
        optimizer: optimizer stepping ``model``'s parameters.
        loss_fn: metric loss, called as ``loss_fn(embeddings, labels, pairs)``.
        miner: pair miner, called as ``miner(embeddings, labels)``.
        train_data: unused; kept for call-site compatibility.
        train_dataloader: iterable of ``(eeg, labels)`` batches.

    Returns:
        List with one detached NumPy loss value per batch.
    """
    running_loss = []

    tq = tqdm(train_dataloader)
    for eeg, labels in tq:
        eeg    = eeg.to(device).squeeze(1)
        labels = labels.to(device)
        optimizer.zero_grad()

        # NOTE(review): `Sur` was built with probability=0.5, but calling
        # `operation` directly with sampled params may bypass that gate and
        # augment every batch; confirm this is intended.
        sur_params = Sur.get_augmentation_params(eeg, labels)
        eeg, labels = Sur.operation(eeg, labels,
                                    phase_noise_magnitude=sur_params['phase_noise_magnitude'],
                                    channel_indep=sur_params['channel_indep'],
                                    random_state = sur_params['random_state'])

        x_proj     = model(eeg)
        hard_pairs = miner(x_proj, labels)
        loss       = loss_fn(x_proj, labels, hard_pairs)
        loss.backward()
        optimizer.step()

        running_loss.append(loss.detach().cpu().numpy())

        tq.set_description('Train:[{}, {:0.3f}]'.format(epoch, np.mean(running_loss)))

    if epoch % 5 == 0:
        eeg_featvec_proj, labels_array = _embed_dataloader(model, train_dataloader)
        num_clusters        = 10
        k_means             = K_means(n_clusters=num_clusters)
        clustering_acc_proj = k_means.transform(eeg_featvec_proj, labels_array)
        print("[Epoch: {}, Train KMeans score Proj: {}]".format(epoch, clustering_acc_proj))
    return running_loss


def validation(epoch, model, optimizer, loss_fn, miner, train_data, val_dataloader):
    """Compute validation loss and k-means clustering accuracy for one epoch.

    The model runs in eval mode under ``torch.no_grad()``. Per-batch metric loss
    is recorded, and all embeddings are then clustered with ``K_means``
    (10 clusters) against the true labels. The model is put back in train mode
    afterwards, even if an error occurs.

    Args:
        epoch: current epoch index, used for logging.
        model: embedding network, called as ``model(eeg)``.
        optimizer: unused; kept for call-site compatibility.
        loss_fn: metric loss, called as ``loss_fn(embeddings, labels, pairs)``.
        miner: pair miner, called as ``miner(embeddings, labels)``.
        train_data: unused; kept for call-site compatibility.
        val_dataloader: iterable of ``(eeg, labels)`` batches.

    Returns:
        ``(running_loss, clustering_acc_proj)``: per-batch NumPy loss values and
        the k-means accuracy over all validation embeddings.

    Raises:
        ValueError: if ``val_dataloader`` yields no batches.
    """
    running_loss    = []
    feature_batches = []
    label_batches   = []

    model.eval()
    try:
        tq = tqdm(val_dataloader)
        with torch.no_grad():
            for eeg, labels in tq:
                eeg    = eeg.to(device).squeeze(1)
                labels = labels.to(device)
                x_proj     = model(eeg)
                hard_pairs = miner(x_proj, labels)
                loss       = loss_fn(x_proj, labels, hard_pairs)
                running_loss.append(loss.detach().cpu().numpy())

                tq.set_description('Val:[{}, {:0.3f}]'.format(epoch, np.mean(running_loss)))

                feature_batches.append(x_proj.cpu().numpy())
                label_batches.append(labels.cpu().numpy())
    finally:
        model.train()

    if not feature_batches:
        raise ValueError('val_dataloader yielded no batches')
    # Concatenate once instead of per batch to avoid quadratic copying.
    eeg_featvec_proj = np.concatenate(feature_batches, axis=0)
    labels_array     = np.concatenate(label_batches, axis=0)

    ### compute the k-means clustering score on the EEG embeddings
    num_clusters   = 10
    print(eeg_featvec_proj.shape)
    k_means        = K_means(n_clusters=num_clusters)
    clustering_acc_proj = k_means.transform(eeg_featvec_proj, labels_array)
    print("[Epoch: {}, Val KMeans score Proj: {}]".format(epoch, clustering_acc_proj))
    return running_loss, clustering_acc_proj


def compute_fft(data):
    """Return the magnitude of the discrete Fourier transform of ``data`` along axis 0."""
    spectrum = np.fft.fft(data, axis=0)
    magnitude = np.abs(spectrum)
    return magnitude


def _to_model_arrays(X, Y):
    """Convert a raw data split into an EEG array and an integer label array.

    Equivalent to converting sample by sample: each ``X[i]`` is transposed with
    axes ``(2, 0, 1)`` and the resulting leading axis (which must have length 1)
    is squeezed away; each label is the argmax of the flattened ``Y[i]``.

    Returns:
        ``(eeg, labels)`` with shapes ``(N, X.shape[1], X.shape[2])`` and ``(N,)``.
    """
    eeg = np.ascontiguousarray(np.squeeze(np.transpose(X, (0, 3, 1, 2)), axis=1))
    labels = np.argmax(np.reshape(Y, (len(Y), -1)), axis=1)
    return eeg, labels


if __name__ == '__main__':

    seed = 45
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    base_path = ""

    # NOTE(review): pickle.load can execute arbitrary code; only load a trusted file.
    with open(base_path + "thoughtviz/eeg/image/data.pkl", 'rb') as file:
        data = pickle.load(file, encoding='latin1')
    train_X = data['x_train']
    train_Y = data['y_train']
    val_X = data['x_test']
    val_Y = data['y_test']

    # Move 4000 randomly chosen training samples into the validation split.
    num_train_samples = train_X.shape[0]
    indices = np.random.choice(num_train_samples, 4000, replace=False)
    selected_train_X = train_X[indices]
    selected_train_Y = train_Y[indices]
    train_X = np.delete(train_X, indices, axis=0)
    train_Y = np.delete(train_Y, indices, axis=0)
    val_X = np.append(val_X, selected_train_X, axis=0)
    val_Y = np.append(val_Y, selected_train_Y, axis=0)

    # ## hyperparameters
    batch_size     = 256
    EPOCHS         = 2048

    ## Training data
    print(train_X.shape[0])
    x_train_eeg, train_labels = _to_model_arrays(train_X, train_Y)
    print(x_train_eeg.shape, train_labels.shape)
    print('Total number of classes: {}'.format(len(np.unique(train_labels))), np.unique(train_labels))

    x_train_eeg  = torch.from_numpy(x_train_eeg).float().unsqueeze(1)
    train_labels = torch.from_numpy(train_labels).long()

    torch.save(x_train_eeg, "train_data_eeg_thought.pt")
    torch.save(train_labels, "train_data_labels_thought.pt")

    train_data       = EEGDataset(x_train_eeg, train_labels)
    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory=False, drop_last=True)

    ## Validation data
    x_val_eeg, label_val = _to_model_arrays(val_X, val_Y)
    print(x_val_eeg.shape, label_val.shape)
    print('Total number of classes: {}'.format(len(np.unique(label_val))), np.unique(label_val))

    x_val_eeg = torch.from_numpy(x_val_eeg).float().to(device).unsqueeze(1)
    label_val = torch.from_numpy(label_val).long().to(device)

    torch.save(x_val_eeg, "test_data_eeg_thought.pt")
    torch.save(label_val, "test_data_labels_thought.pt")

    val_data       = EEGDataset(x_val_eeg, label_val)
    torch.save(val_data, "val_data.pt")
    # drop_last=False so every validation sample contributes to the loss and k-means score.
    val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False, pin_memory=False, drop_last=False)

    model = SimpleViT(
            image_size = (14, 32),
            patch_size = (2, 2),
            num_classes = 40,
            dim = 256,
            depth = 4,
            dim_head=16,
            heads = 16,
            mlp_dim = 64,
            channels = 1
        )
    model     = torch.nn.DataParallel(model).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=400, eta_min=0, last_epoch=-1)

    START_EPOCH = 0
    pre = False  # set True to resume model/optimizer/scheduler state from ckpt_path
    if pre:
        ckpt_path  = '/home/ubuntu/bestckpt/eegfeat_all_0.5954861111111112.pth'
        checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        START_EPOCH = checkpoint['epoch']
        print('Loading checkpoint from previous epoch: {}'.format(START_EPOCH))
        START_EPOCH += 1

    # Training runs whether or not a checkpoint was resumed (it used to be nested
    # under `if pre:`, so a fresh run never trained).
    miner   = miners.MultiSimilarityMiner()
    loss_fn = losses.TripletMarginLoss()

    best_val_acc   = 0.0
    best_val_epoch = 0
    os.makedirs('bestckpt', exist_ok=True)

    for epoch in range(START_EPOCH, EPOCHS):
        running_train_loss = train(epoch, model, optimizer, loss_fn, miner, train_data, train_dataloader)
        running_val_loss, val_acc = validation(epoch, model, optimizer, loss_fn, miner, train_data, val_dataloader)
        scheduler.step()
        if best_val_acc < val_acc:
            best_val_acc   = val_acc
            best_val_epoch = epoch
            # File name embeds the accuracy, matching the eegfeat_all_<acc>.pth resume path.
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
            }, 'bestckpt/eegfeat_{}_{}.pth'.format('all', val_acc))
